#!/usr/bin/env python3

# (c) Copyright 2018-2019 CORSIKA Project, corsika-project@lists.kit.edu
#
# See file AUTHORS for a list of contributors.
#
# This software is distributed under the terms of the 3-clause BSD license.
# See file LICENSE for a full version of the license.

import pickle, sys, itertools



def load_particledb(filename):
    '''
    loads and returns the pickled particle_db (an OrderedDict) stored in filename

    NOTE: pickle.load can execute arbitrary code embedded in the file; only use
    this on particle_db files produced by the trusted build, never on external input.
    '''
    with open(filename, "rb") as pickled:
        return pickle.load(pickled)



def read_epos_codes(filename, particle_db):
    '''
    reads the epos codes data file

    Blank lines and lines starting with '#' are ignored. Every other line must
    contain exactly four whitespace-separated fields:

        <CORSIKA identifier> <epos code> <can-interact flag> <xs type>

    For each particle listed, 'epos_code' (int) and 'epos_xsType' (str) are added
    to its entry in particle_db (modified in place). The can-interact flag is
    read but not used.

    Raises ValueError if a line does not have exactly four fields or the epos
    code is not an integer, and Exception if an identifier is not present in
    particle_db.
    '''
    with open(filename) as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            fields = line.split()
            if len(fields) != 4:
                # report the location instead of an opaque unpacking error
                raise ValueError(
                    "{}: line {:d}: expected 4 fields, got {:d}: '{:s}'".format(
                        filename, lineno, len(fields), line))
            identifier, epos_code, _canInteractFlag, xsType = fields
            # only the lookup can raise KeyError; keep the try body minimal
            try:
                pData = particle_db[identifier]
            except KeyError as e:
                raise Exception("Identifier '{:s}' not found in CORSIKA8 particle_db".format(identifier)) from e
            pData["epos_code"] = int(epos_code)
            pData["epos_xsType"] = xsType

            
def set_default_epos_definition(particle_db):
    '''
    Assigns default EPOS interaction properties to nuclei.

    Nuclei (entries whose 'isNucleus' is true) may interact via EPOS even
    without an explicit entry in the epos codes file. They are mapped to the
    cross-section type "Baryon" ('epos_xsType') and the hadron type
    "NucleusType" ('epos_hadronType').

    Only missing fields are filled in. Values already present, e.g. those set by
    read_epos_codes (which the script runs before this function), are kept.
    Non-nuclei are left untouched.

    The function returns nothing; it modifies particle_db in place.
    '''
    for pData in particle_db.values():
        if not pData['isNucleus']:
            continue
        # setdefault: defaults must not overwrite explicit data-file values
        pData.setdefault('epos_xsType', "Baryon")
        pData.setdefault('epos_hadronType', "NucleusType")


def generate_epos_enum(particle_db):
    '''
    returns the C++ source of `enum class EposCode`, which gives each particle
    that has an 'epos_code' a readable name, in particle_db iteration order

    NOTE(review): corsika2epos refers to EposCode::Unknown, so the epos codes
    file presumably has an "Unknown" entry; confirm.
    '''
    enumerators = [
        "  {:s} = {:d},\n".format(name, entry['epos_code'])
        for name, entry in particle_db.items()
        if 'epos_code' in entry
    ]
    return "enum class EposCode : int32_t {\n" + "".join(enumerators) + "};\n"



def generate_corsika2epos(particle_db):
    '''
    returns the C++ source of the look-up table `corsika2epos`, which maps
    corsika codes (particle_db order) to EposCode values

    Particles without an 'epos_code' map to EposCode::Unknown. Nuclei are skipped.

    NOTE(review): because nuclei are skipped, the initializer list can be shorter
    than the declared size len(particle_db); C++ value-initializes the remaining
    elements. The indices only line up if all nuclei come last in particle_db.
    Confirm this ordering.
    '''
    header = "std::array<EposCode, {:d}> constexpr corsika2epos = {{\n".format(len(particle_db))
    rows = []
    for name, entry in particle_db.items():
        if entry['isNucleus']:
            continue
        if 'epos_code' not in entry:
            rows.append("  EposCode::Unknown, // {:s}\n".format(name + ' not implemented in EPOS'))
        else:
            rows.append("  EposCode::{:s}, \n".format(name))
    return header + "".join(rows) + "};\n"
    


def generate_corsika2epos_xsType(particle_db):
    '''
    returns the C++ source of the look-up table `corsika2eposXStype`, which maps
    corsika codes (particle_db order) to EPOS cross-section classes (EposXSClass)

    Each particle uses its 'epos_xsType'. Particles without one map to
    EposXSClass::CannotInteract. Nuclei are skipped.

    NOTE(review): as in corsika2epos, skipping nuclei can leave fewer
    initializers than the declared size len(particle_db). Confirm that nuclei
    come last in particle_db.
    '''
    header = "std::array<EposXSClass, {:d}> constexpr corsika2eposXStype = {{\n".format(len(particle_db))
    rows = []
    for name, entry in particle_db.items():
        if entry['isNucleus']:
            continue
        if 'epos_xsType' not in entry:
            rows.append("  EposXSClass::CannotInteract, // {:s}\n".format(name + ' not implemented in EPOS'))
        else:
            rows.append("  EposXSClass::{:s}, // {:s}\n".format(entry['epos_xsType'], name))
    return header + "".join(rows) + "};\n"


def generate_epos2corsika(particle_db):
    '''
    generates the look-up table to convert epos codes to corsika codes

    Emits the constant `minEpos`, which is the smallest EPOS code, or 0 if no
    code is negative. It is followed by `std::array<corsika::Code, N> epos2corsika`.
    Element i of that array is the CORSIKA code of the particle with EPOS code
    i + minEpos. Gaps in the code range are filled with corsika::Code::Unknown.
    If two particles share an EPOS code, the later one in particle_db wins.

    Raises ValueError if no particle in particle_db has an 'epos_code'.
    '''
    epos_codes = {identifier: pData['epos_code']
                  for identifier, pData in particle_db.items()
                  if 'epos_code' in pData}
    if not epos_codes:
        raise ValueError("no particle in particle_db has an 'epos_code'; cannot build epos2corsika")

    # the offset is never positive: with only positive codes the table starts at code 0
    minID = min(0, min(epos_codes.values()))

    # table index (epos code - minID) -> corsika identifier
    pDict = {code - minID: identifier for identifier, code in epos_codes.items()}

    # Indices run from 0 to max(pDict), so the table needs max(pDict) + 1 slots.
    # The former "max - min + 1" dropped the largest code whenever all codes were positive.
    nPart = max(pDict) + 1

    parts = [
        "EposCodeIntType constexpr minEpos = {:d};\n\n".format(minID),
        "std::array<corsika::Code, {:d}> constexpr epos2corsika = {{\n".format(nPart),
    ]
    parts.extend("  corsika::Code::{:s}, \n".format(pDict.get(iPart, "Unknown"))
                 for iPart in range(nPart))
    parts.append("};\n")
    return "".join(parts)

if __name__ == "__main__":
    # two arguments are required: the pickled particle_db and the epos codes file
    if len(sys.argv) != 3:
        print("usage: {:s} <particle_db.pkl> <epos_codes.dat>".format(sys.argv[0]), file=sys.stderr)
        sys.exit(1)

    print("code_generator.py for EPOS")

    db_filename, codes_filename = sys.argv[1], sys.argv[2]
    particle_db = load_particledb(db_filename)
    read_epos_codes(codes_filename, particle_db)
    set_default_epos_definition(particle_db)

    # sections of Generated.inc, in the order they are written
    section_generators = (
        generate_epos_enum,
        generate_corsika2epos,
        generate_epos2corsika,
        generate_corsika2epos_xsType,
    )
    with open("Generated.inc", "w") as out:
        print("// this file is automatically generated\n// edit at your own risk!\n", file=out)
        for generate_section in section_generators:
            print(generate_section(particle_db), file=out)
